# encoding: utf-8
import os.path as op
import os
import numpy as np

# Directory that holds this file, with symlinks resolved by realpath().
project_path = op.dirname(op.realpath(__file__))

# Sub-directories located directly under project_path. The paths are only
# composed here; nothing in this module checks that they exist.
wm_raw_data_dir, sk_raw_data_dir, image_dir, data_dir = (
    op.join(project_path, _folder)
    for _folder in (
        'CVPR2017_WM_full',
        'CVPR2017_SK_full',
        'imageNet_images',
        'data',
    )
)


class Config_Generative_Model:
    """Bundle of default settings exposed as plain instance attributes.

    The constructor takes no arguments: every attribute is assigned a
    hard-coded default, and the two derived paths are built from the
    module-level ``project_path``.

    Side effect: each instantiation re-seeds NumPy's global random
    generator with ``self.seed``.

    How the individual values are consumed is decided by the code that
    reads this object, which is not visible in this module.
    """

    def __init__(self):
        base_dir = project_path

        # ---- project parameters ----
        self.seed = 2022
        self.root_path = base_dir
        self.output_path = op.join(base_dir, 'exps')
        self.pretrain_gm_path = op.join(base_dir, 'pretrains')

        self.dataset = 'WM'
        self.eeg_encoder_path = None  # no encoder path configured by default
        self.imagenet_path = image_dir
        self.img_size = 512

        # Process-wide effect: resets NumPy's global RNG on every construction.
        np.random.seed(self.seed)

        # ---- finetune parameters ----
        self.batch_size = 2
        self.lr = 5e-4
        self.num_epoch = 500
        self.precision = 32
        self.accumulate_grad = 1
        self.crop_ratio = 0.2
        self.global_pool = False
        self.use_time_cond = True
        self.clip_tune = self.cls_tune = False  # both tuning switches start off
        self.subject = 0
        self.eval_avg = True
        # NOTE(review): stored as the int 1, not True; looks like a flag --
        # confirm with the consumer before changing the type.
        self.train_cond_stage_only = 1

        # ---- diffusion sampling parameters ----
        self.num_samples = 5
        self.ddim_steps = 250
        self.HW = None  # NOTE(review): presumably a height/width override -- confirm

        # ---- resume check util ----
        self.model_meta = self.checkpoint_path = None